import numpy as np
import torch
import torch.nn as nn

class CatGPU(nn.Module):
    """Kernel model with learnable RBF scales for an input ``X`` and a target ``Y``.

    Two log-scale parameters (``Xlength_scale_`` and ``Ylength_scale_``) are
    learned by minimising the value returned by :meth:`Score`, which combines
    the likelihood term of :meth:`ConditionalLikelihood` with the
    log-derivative term of :meth:`Jacobian`.

    ``param`` must contain ``'device'``; ``'width_init'`` is optional and, when
    present, is used as the initial width for the ``Y`` kernel instead of the
    median heuristic.
    """

    def __init__(self, X, Y, param):
        """Initialise the kernel widths from the data and register parameters.

        :param X: 2-D tensor of inputs (one sample per row).
        :param Y: 2-D tensor of targets (one sample per row).
        :param param: dict with key ``'device'`` and optional ``'width_init'``.
        """
        super().__init__()
        self.param = param
        # Constant added to the diagonal of the X kernel (see Xkernel_mat_self).
        self.Xnoise_scale = 0.01
        self.device = param['device']
        self.Yinit = param.get('width_init')
        self.Xlength = self.kernel_init(X)
        self.Ylength = self.kernel_init(Y, self.Yinit)
        # The parameters live in log-space so the effective scales stay positive.
        self.Xlength_scale_ = nn.Parameter(torch.tensor(np.log(self.Xlength).astype(float), device=self.device))
        self.Ylength_scale_ = nn.Parameter(torch.tensor(np.log(self.Ylength).astype(float), device=self.device))

    @property
    def Xlength_scale(self):
        """Effective X kernel scale: ``2 * exp(Xlength_scale_)``."""
        return 2*torch.exp(self.Xlength_scale_)

    @property
    def Ylength_scale(self):
        """Effective Y kernel scale: ``2 * exp(Ylength_scale_)``."""
        return 2*torch.exp(self.Ylength_scale_)

    @staticmethod
    def _pairwise_sqdist(A, B):
        """Return the matrix of squared Euclidean distances between rows of A and rows of B."""
        a_sq = (A ** 2).sum(dim=1, keepdim=True)
        b_sq = (B ** 2).sum(dim=1, keepdim=True)
        return a_sq + b_sq.T - 2 * A.mm(B.T)

    def kernel_init(self, X, weight=None):
        """Return the initial kernel width as a NumPy value.

        If ``weight`` is given it is returned (converted to a NumPy array) and
        ``X`` is ignored.  Otherwise the width is ``sqrt(0.5 * median)`` of the
        strictly positive pairwise squared distances between the rows of ``X``.
        When there is no strictly positive distance (all rows identical, or
        fewer than two rows) the fallback width ``1`` is returned.

        :param X: 2-D tensor, one sample per row.
        :param weight: optional explicit initial width (scalar, sequence, array or tensor).
        :return: NumPy scalar/array, or the int ``1`` for the fallback.
        """
        if weight is not None:
            if torch.is_tensor(weight):
                # Move to host memory first: .numpy() is only defined for CPU tensors.
                return weight.detach().clone().cpu().numpy()
            return torch.tensor(weight).numpy()
        sqdist = self._pairwise_sqdist(X, X)
        # Each pair is counted once: keep only the strict upper triangle.
        upper = torch.triu(sqdist, diagonal=1)
        positive = upper[upper > 0]
        if positive.numel() == 0:
            # Degenerate data; also avoids calling median on an empty tensor.
            return 1
        width = torch.sqrt(0.5 * torch.median(positive))
        # .cpu() is required so that this also works when X lives on a GPU.
        return width.detach().cpu().numpy()

    def Xkernel_mat_self(self, X):
        """Return the X kernel matrix of ``X`` with itself, plus ``Xnoise_scale`` on the diagonal."""
        K = torch.exp(- 0.5 * self._pairwise_sqdist(X, X) / self.Xlength_scale)
        # Build the identity directly with the kernel's dtype/device instead of
        # creating it on the CPU and copying it over.
        jitter = torch.eye(X.shape[0], dtype=K.dtype, device=K.device)
        return K + self.Xnoise_scale * jitter

    def Ykernel_mat_self(self, Y):
        """Return the Y kernel matrix of ``Y`` with itself (no diagonal term added)."""
        return torch.exp(- 0.5 * self._pairwise_sqdist(Y, Y) / self.Ylength_scale)

    def ConditionalLikelihood(self, X, Y):
        """Compute the likelihood term and cache the intermediate matrices.

        With ``KXX = Xkernel_mat_self(X)``, ``KYY = Ykernel_mat_self(Y)`` and
        ``L`` the lower Cholesky factor of ``KXX``, this evaluates::

            -0.5 * trace(KYY @ KXX^{-1} @ KYY)
                - n * (sum(log(diag(L))) - n * 0.5 * log(2 * pi))

        and returns it divided by ``n`` (the number of rows of ``X``).

        Side effect: stores ``KX``, ``KY``, ``X``, ``Y``, ``L`` and ``alpha``
        (``KXX^{-1} @ KYY``) on the instance; :meth:`Jacobian`, :meth:`predict`
        and :meth:`calculateKyY` read them.

        :param X: 2-D tensor of inputs, one sample per row.
        :param Y: 2-D tensor of targets with the same number of rows as ``X``.
        :return: scalar tensor.
        """
        n = X.shape[0]
        KXX = self.Xkernel_mat_self(X)
        KYY = self.Ykernel_mat_self(Y)
        L = torch.linalg.cholesky(KXX)
        # Solve KXX @ alpha = KYY by reusing the Cholesky factor (two triangular
        # substitutions) rather than two general LU-based solves.
        alpha = torch.cholesky_solve(KYY, L)
        # NOTE(review): because of the parenthesisation the 2*pi constant enters
        # with a positive sign.  It does not depend on the parameters, so it
        # does not affect gradients -- confirm whether this is intended.
        marginal_likelihood = -0.5*torch.trace(KYY.mm(alpha)) - n*(torch.log(torch.diag(L)).sum() - n * 0.5 * np.log(2 * np.pi))
        self.KX = KXX
        self.KY = KYY
        self.Y = Y
        self.X = X
        self.L = L
        self.alpha = alpha
        return marginal_likelihood / n

    def Jacobian(self, y):
        """
        Log-determinant-of-Jacobian style term built from kernel derivatives.

        Requires :meth:`ConditionalLikelihood` to have been called (reads
        ``self.Y``), and ``y`` must have as many rows as ``self.Y`` for the
        diagonal-removal reshape below to be valid.

        :param y: 2-D tensor, one sample per row.
        :return: scalar tensor: sum of ``log(|d k(y_i, Y_j) / d y_i| + 1e-15)``
            over the non-zero off-diagonal entries, divided by ``n``.
        """
        v = y.shape[-1]
        Y = self.Y
        n = y.shape[0]
        N = Y.shape[0]

        KyY = self.Ykernel_mat(y, Y).expand(v, n, N)
        # dist[d, i, j] = y[i, d] - Y[j, d]
        ym = y.T.unsqueeze(2)
        Ym = Y.T.unsqueeze(1)
        dist = (ym - Ym)
        # Derivative of exp(-0.5 * sqdist / scale) with respect to y, per dimension.
        driv_KyY = - torch.mul(dist, KyY) / self.Ylength_scale
        # Only the first feature dimension (index 0) is used below.
        # Drop the diagonal of the (square) matrix: after removing the last
        # element, viewing the flat data as (n-1, n+1) puts every remaining
        # diagonal entry in column 0, which is then sliced away.
        driv_KyY_diag  = driv_KyY[0].T.flatten()[:-1].view(n-1, n+1)[:, 1:].flatten().view(n, n-1).T
        abs_deriv_diag = torch.abs(driv_KyY_diag)
        # Exact zeros are skipped so that they do not contribute log(1e-15).
        nonZerosIdx = abs_deriv_diag > 0
        log_deriv_diag = torch.log(abs_deriv_diag[nonZerosIdx]+1e-15)
        driv_score = torch.sum(log_deriv_diag)
        return driv_score / n

    def Score(self, x, y):
        """
        Objective combining the likelihood term and the derivative term.

        Both inputs are cast to float32 and moved to ``self.device``.  Note
        that :meth:`ConditionalLikelihood` and :meth:`Jacobian` already divide
        by ``n``; their results are divided by ``n`` once more here.

        :param x: independent variable, one sample per row.
        :param y: dependent variable, same number of rows as ``x``.
        :return: tuple ``(nll, marginal_likelihood, driv_term)`` where
            ``nll = -(marginal_likelihood + driv_term)``.
        """
        n = x.shape[0]
        x = x.float().to(self.device)
        y = y.float().to(self.device)
        marginal_likelihood = self.ConditionalLikelihood(x, y)/n
        driv_term = self.Jacobian(y)/n
        nll = - (marginal_likelihood + driv_term)
        return nll, marginal_likelihood, driv_term

    def Xkernel_mat(self, X, Z):
        """Return the X kernel matrix between rows of ``X`` and rows of ``Z`` (no diagonal term)."""
        return torch.exp(- 0.5 * self._pairwise_sqdist(X, Z) / self.Xlength_scale)

    def Ykernel_mat(self, y, Y):
        """Return the Y kernel matrix between rows of ``y`` and rows of ``Y``."""
        return torch.exp(- 0.5 * self._pairwise_sqdist(y, Y) / self.Ylength_scale)

    def predict(self, x):
        """Return ``k(x, X) @ alpha`` for test inputs ``x``.

        :meth:`ConditionalLikelihood` (directly or through :meth:`Score`) must
        have been called first, because the cached ``self.X`` and
        ``self.alpha`` are used.

        :param x: 2-D tensor of test inputs, one sample per row; cast to
            float32 and moved to ``self.device``.
        :return: tensor with one row per test input.
        """
        x = x.float().to(self.device)
        kxX = self.Xkernel_mat(x, self.X)
        return kxX.mm(self.alpha)

    def calculateKyY(self, y):
        """
        Calculate the Y kernel between ``y`` and the cached ``self.Y`` during inference.

        :param y: 2-D tensor, one sample per row; cast to float32 and moved to
            ``self.device``.
        :return: kernel matrix with one row per row of ``y``.
        """
        y = y.float().to(self.device)
        return self.Ykernel_mat(y, self.Y)

    def train_step(self, x, y, opt):
        """Run one optimisation step and return detached diagnostics.

        :param x: independent variable passed to :meth:`Score`.
        :param y: dependent variable passed to :meth:`Score`.
        :param opt: optimiser exposing ``zero_grad()`` and ``step()``.
        :return: dict with ``'score'`` (Python float), ``'nlml'`` and ``'driv'``
            (the second and third values returned by :meth:`Score`), and
            ``'Xlength'`` / ``'Ylength'`` (square roots of the effective
            scales), all detached and on the CPU.
        """
        opt.zero_grad()
        score, likelihood_term, driv_term = self.Score(x, y)

        score.backward()
        opt.step()
        return {'score': score.item(),
                'nlml': likelihood_term.detach().cpu(),
                'driv': driv_term.detach().cpu(),
                'Xlength': torch.sqrt(self.Xlength_scale).detach().cpu(),
                'Ylength': torch.sqrt(self.Ylength_scale).detach().cpu(),
                }





